"""Run .gitlab-ci.yml code through shellcheck."""

import argparse
import copy
import functools
import hashlib
import io
import json
import os
import pathlib
import re
import subprocess
import sys
import typing
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import TYPE_CHECKING
from typing import Tuple
from typing import Union

from cki_lib import config_tree
from cki_lib import yaml


class LineStr(str):
    """String that can keep track of a line number and source file."""

    def __new__(cls, value: str,
                line: Optional[int] = None,
                filepath: Optional[pathlib.Path] = None,
                ) -> 'LineStr':
        # pylint: disable=unused-argument
        """Create a new instance of the class.

        str is immutable, so the extra arguments need to be accepted (and
        ignored) here; they are stored by __init__ instead.
        """
        return str.__new__(cls, value)

    def __init__(self, value: str,
                 line: Optional[int] = None,
                 filepath: Optional[pathlib.Path] = None,
                 ) -> None:
        # pylint: disable=unused-argument
        """Initialize and keep track of the line number if provided.

        line: 1-based source line of the string, if known
        filepath: source file of the string, if known
        """
        super().__init__()
        self.line = line
        self.filepath = filepath

    def split(self, *args: Any, **kwargs: Any) -> List[str]:
        """Split and keep track of the line number.

        Each part advances the line number by one, which is only accurate
        when splitting on newlines (the only way split() is used on script
        contents in this module).  When no line number is known, the parts
        keep an unknown line number instead of raising TypeError.
        """
        parts = super().split(*args, **kwargs)
        if self.line is None:
            # no line information to propagate
            return [LineStr(part, None, self.filepath) for part in parts]
        return [LineStr(part, self.line + index, self.filepath)
                for index, part in enumerate(parts)]


class LineTrackingLoader(yaml.ReferenceLoader):
    # pylint: disable=too-many-ancestors
    """YAML loader that keeps track of source lines."""

    def construct_scalar(
        self, node: 'yaml.nodes.ScalarNode'
    ) -> Union[str, int, float, bool, None]:
        """Keep the source line on everything that looks like a string."""
        if not isinstance(node.value, str):
            return super().construct_scalar(node)
        # start_mark.line is 0-based, so add one for 1-based line numbers;
        # when a style is set (e.g. block scalars), start_mark refers to the
        # line of the indicator and the value only starts on the next line
        offset = 2 if node.style else 1
        return LineStr(node.value, node.start_mark.line + offset)


class YamlShellLinter:
    # pylint: disable=too-many-instance-attributes
    """Lint the shell code in a .gitlab-ci.yml file."""

    # directory where shellcheck results are cached between runs
    cache_path = pathlib.Path('~/.cache/gitlab-yaml-shellcheck').expanduser()
    # maximum number of cached results kept on disk
    cache_depth = 500

    def __init__(
        self,
        file: io.TextIOWrapper,
        *,
        stdin_filename: str,
        check_sourced: bool,
        verbose: bool,
        unknown_args: list[str],
    ):
        # pylint: disable=too-many-arguments
        """Read the pipeline yaml.

        file: open handle of the pipeline file (possibly stdin)
        stdin_filename: file name to assume for contents read from stdin
        check_sourced: show warnings from followed "source" commands
        verbose: print progress information
        unknown_args: extra command line options passed through to shellcheck
        """
        self._cli_path = pathlib.Path(stdin_filename or file.name)
        if self._cli_path.exists():
            self._cli_path = self._cli_path.resolve()
            cli_parent_path = self._cli_path.parent
        else:
            # e.g. contents piped via stdin without --stdin-filename
            cli_parent_path = pathlib.Path.cwd()
        with file:
            contents = file.read()
        self._cli_data, cli_config = YamlShellLinter._load(contents, self._cli_path)

        # a "main=..." comment directive points at the top-level pipeline
        # file that includes the file given on the command line
        if cli_config.get('main'):
            self._main_path = cli_parent_path.joinpath(cli_config['main'][0]).resolve()
            self._main_parent_path = self._main_path.parent
            self._main_data, main_config = YamlShellLinter._load(
                self._main_path.read_text(encoding='utf8'), self._main_path)
        else:
            self._main_path, self._main_parent_path = self._cli_path, cli_parent_path
            self._main_data, main_config = self._cli_data, cli_config

        self._main_data = self._resolve_includes(self._main_data)
        self._main_data = yaml.ReferenceList.resolve_references(self._main_data)
        self._check_sourced = check_sourced
        self._verbose = verbose
        self._unknown_args = unknown_args
        # "predefined=A,B" comment directives declare variables that are
        # assumed to be defined during job execution
        self._predefined = [p for p in
                            ','.join(main_config.get('predefined', [])).split(',')
                            if p]

        self._filesystem_cache: Dict[str, Tuple[int, List[str]]] = {}
        self._memory_cache: Dict[str, Tuple[int, List[str]]] = {}

    @staticmethod
    def _load_config(contents: str) -> typing.Dict[str, List[str]]:
        """Extract "# gitlab-yaml-shellcheck: key=value ..." comment directives.

        Returns a dict mapping each key to the list of all its values.
        """
        config: typing.Dict[str, List[str]] = {}
        for match in re.finditer(r'^#\s*gitlab-yaml-shellcheck:((?:\s+[^\s]+=[^\s]+)+)',
                                 contents, re.MULTILINE):
            for key_value in match.group(1).split():
                # maxsplit=1 so values are allowed to contain '=' themselves
                key, value = key_value.split('=', 1)
                config.setdefault(key, []).append(value)
        return config

    @staticmethod
    def _load(data: str, path: pathlib.Path) -> typing.Any:
        """Parse yaml contents and the embedded linter configuration.

        Returns a tuple of (parsed yaml data, comment directive config).
        """
        # load_all to accept multi-document files; only the last document
        # (the actual pipeline configuration) is used
        parsed_data = yaml.load(contents=data, loader=LineTrackingLoader, load_all=True)[-1]
        YamlShellLinter._add_filepath(parsed_data, path)
        config = YamlShellLinter._load_config(data)
        return parsed_data, config

    @staticmethod
    def _add_filepath(data: typing.Any, filepath: pathlib.Path) -> None:
        """Add the source file name to LineStr instances."""
        if isinstance(data, dict):
            for value in data.values():
                YamlShellLinter._add_filepath(value, filepath)
        elif isinstance(data, list):
            for value in data:
                YamlShellLinter._add_filepath(value, filepath)
        elif isinstance(data, LineStr) and data.filepath is None:
            data.filepath = filepath

    def _resolve_includes(self,
                          data: typing.Mapping[str, typing.Any]
                          ) -> typing.Dict[str, typing.Any]:
        """Resolve local include references."""
        result: typing.Dict[typing.Any, typing.Any] = {}
        includes = data.get('include') or []
        if not isinstance(includes, list):
            includes = [includes]
        for include in includes:
            # a bare string include is shorthand for a local include
            if isinstance(include, str):
                include = {'local': include}
            if 'local' not in include:
                raise ValueError('Non-local includes are not supported')
            # GitLab has a different idea what ** means 🤦
            glob_pattern = str(include['local']).replace('/**', '/**/*')
            paths = list(self._main_parent_path.glob(glob_pattern))
            if not paths:
                raise ValueError(f'Include not found: {glob_pattern}')
            for path in paths:
                if path.resolve() == self._cli_path:  # use the contents provided on the CLI
                    include_data = self._cli_data
                else:
                    include_data, _ = YamlShellLinter._load(path.read_text(encoding='utf8'), path)
                config_tree.merge_dicts(result, include_data)
        # the including file takes precedence over included files
        config_tree.merge_dicts(result, data)
        return result

    @staticmethod
    def merge(results: List[Tuple[int, List[str]]]) -> Tuple[int, List[str]]:
        """Sensibly merge multiple tuples of (returncode, messages).

        Returns a tuple of (highest returncode, messages).
        """
        if not results:
            # functools.reduce() would raise TypeError on an empty sequence
            return (0, [])
        return functools.reduce(
            lambda a, b: (max(a[0], b[0]), list(set(a[1]) | set(b[1]))), results)

    def all_job_names(self) -> List[str]:
        """Return a list of the names of all jobs."""
        # any top-level mapping is considered a job definition
        return [k for k in self._main_data.keys()
                if isinstance(self._main_data[k], dict)]

    def _merge_extends(self, key: str) -> Dict[str, Any]:
        """Recursively merge a job with the jobs it extends.

        Keys of the extending job take precedence over extended jobs; the
        resulting dict no longer contains the 'extends' key.
        """
        result = {}
        extends = self._main_data[key].get('extends')
        if extends:
            if not isinstance(extends, list):
                extends = [extends]
            for extend in extends:
                result.update(self._merge_extends(extend))
        result.update(copy.deepcopy(self._main_data[key]))
        result.pop('extends', None)
        return result

    def lint_job(self, job_name: str) -> Tuple[int, List[str]]:
        """Run the linter on the given job of the yaml.

        Returns a tuple of (highest returncode, messages).
        """
        job = self._merge_extends(job_name)
        variables = self._get_dict((job, self._main_data), 'variables')
        before_script = self._clean_script(
            self._get_array((job, self._main_data), 'before_script'))
        script = self._clean_script(
            self._get_array((job, self._main_data), 'script'))
        after_script = self._clean_script(
            self._get_array((job, self._main_data), 'after_script'))

        # dummy values containing shell metacharacters so shellcheck flags
        # unquoted uses of the variables
        exports = [f'export {v}="{{non,word}}* [characters]"'
                   for v in list(variables.keys()) + self._predefined]
        common = (
            ['#!/bin/bash'] +
            ['### setup'] +
            ['set -eo pipefail'] +  # gitlab-runner/shells/bash.go:writeScript
            ['### variables'] + exports
        )

        results = [self._lint_script(f'{job_name}.script', (
            common +
            ['### before_script'] + before_script +
            ['### script'] + script +
            ['### end']
        ))]
        # after_script runs in a separate shell, so lint it separately
        if after_script:
            results += [self._lint_script(f'{job_name}.after_script', (
                common +
                ['### after_script'] + after_script +
                ['### end']
            ))]
        return self.merge(results)

    def load_cache(self) -> None:
        """Load old shellcheck results from the filesystem."""
        for path in self.cache_path.glob('*.json'):
            # JSON serializes the (returncode, messages) tuple as a list;
            # convert back so cached entries are shaped like fresh ones
            returncode, messages = json.loads(path.read_bytes())
            self._filesystem_cache[path.stem] = (returncode, messages)

    def save_cache(self) -> None:
        """Save shellcheck results to the filesystem.

        Only the newest cache_depth files are kept on disk.
        """
        self.cache_path.mkdir(parents=True, exist_ok=True)
        for key, value in self._memory_cache.items():
            self.cache_path.joinpath(f'{key}.json').write_text(json.dumps(value), encoding='utf8')
        paths = sorted(self.cache_path.iterdir(), key=os.path.getmtime)
        for path in paths[0:-self.cache_depth]:
            path.unlink()

    @staticmethod
    def _cache_key(options: List[str], script: List[str]) -> str:
        """Determine a unique cache key for the script.

        In the cache key, include all aspects that might have influence on the
        shellcheck output:
        - the script contents
        - the basename and line numbers of the script
        - the shellcheck options
        """
        option_lines = [' '.join(options)]
        script_lines = [':'.join([
            line,
            (getattr(line, "filepath", pathlib.Path('unknown')) or pathlib.Path('unknown')).stem,
            str(getattr(line, "line", -1)),
        ]) for line in script]
        concatenated = '\n'.join(option_lines + script_lines).encode('utf-8')
        return hashlib.sha256(concatenated).hexdigest()

    def _get_cache(self, options: List[str], script: List[str]
                   ) -> Optional[Tuple[int, List[str]]]:
        """Return a cached shellcheck result, or None on a cache miss."""
        digest = self._cache_key(options, script)
        if digest in self._memory_cache:
            return self._memory_cache[digest]
        if digest in self._filesystem_cache:
            return self._filesystem_cache[digest]
        return None

    def _put_cache(self, options: List[str], script: List[str], result:
                   Tuple[int, List[str]]) -> Tuple[int, List[str]]:
        """Store a shellcheck result in the memory cache and return it."""
        digest = self._cache_key(options, script)
        self._memory_cache[digest] = result
        return result

    def _lint_script(self, job: str, script: List[str]) -> Tuple[int, List[str]]:
        """Run shellcheck on the given script lines.

        Returns a tuple of (returncode, messages), served from the cache
        when possible.
        """
        cmd = ['shellcheck', '-f', 'gcc', '-']
        if self._check_sourced:
            cmd += ['-ax']
        else:
            cmd += ['-x']
        options = self._unknown_args

        cached = self._get_cache(options, script)
        if cached:
            return cached

        if self._verbose:
            print(f'Linting {job} via {" ".join(cmd + options)}')
        output = subprocess.run(cmd + options,
                                input='\n'.join(script),
                                encoding='utf8',
                                stdout=subprocess.PIPE,
                                stderr=subprocess.PIPE,
                                check=False)
        messages = [self._mangle_shellcheck_output(script, line)
                    for line in output.stdout.split('\n')
                    if line]
        return self._put_cache(options, script, (output.returncode, messages))

    def _mangle_shellcheck_output(self, script: List[str], line: str) -> str:
        """Map a gcc-format shellcheck message back to the yaml source.

        Replaces the '-' stdin file name and the line number of the
        assembled script by the original yaml file name and line number.
        """
        parts = line.split(':', 3)
        if len(parts) < 4:
            # not a gcc-format "file:line:column: message" line
            return line
        filename, lineno, column, message = parts[:4]
        if filename == '-':
            # column numbers do not survive the yaml round-trip
            column = '1'
            # shellcheck line numbers are 1-based
            if 1 <= int(lineno) <= len(script):
                line_str = script[int(lineno) - 1]
                lineno = str(getattr(line_str, 'line', -1))
                filename = str(getattr(line_str, 'filepath', self._cli_path) or self._cli_path)
            else:
                lineno = '-1'
                filename = str(self._cli_path)
        return ':'.join([filename, lineno, column, message])

    @staticmethod
    def _get_array(yaml_dicts: Tuple[Dict[str, List[str]], Dict[str, List[str]]],
                   field_name: str) -> List[str]:
        """Retrieve an array from the first dict of the yaml that has it."""
        for part in yaml_dicts:
            if field_name in part:
                return part[field_name]
        return []

    @staticmethod
    def _get_dict(yaml_dicts: Tuple[Dict[str, Any], Dict[str, Any]],
                  field_name: str) -> Dict[str, Any]:
        """Assemble a dict from a field of different dicts of the yaml.

        Earlier dicts take precedence over later ones.
        """
        result = {}
        for part in reversed(yaml_dicts):
            if field_name in part:
                result.update(part[field_name])
        return result

    @staticmethod
    def _clean_script(parts: Union[str, List[str]]) -> List[str]:
        """Flatten a yaml script array into a flat list of lines."""
        if isinstance(parts, str):
            return parts.split('\n')
        return [line for part in parts for line in YamlShellLinter._clean_script(part)]


def main(argv: Optional[List[str]] = None) -> int:
    """Run the command line interface.

    argv: command line arguments without the program name; defaults to
    sys.argv[1:].  Returns the highest shellcheck return code.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('file', type=argparse.FileType('r'),
                        help='GitLab CI/CD pipeline file')
    parser.add_argument('--job', action='append',
                        help='Only lint the given jobs')
    parser.add_argument('--check-sourced', action='store_true',
                        help='show warnings from followed "source" commands')
    parser.add_argument('--stdin-filename', default='',
                        help='Use the given file name for the contents of standard input')
    parser.add_argument('--verbose', action='store_true',
                        help='provide progress information')
    parser.add_argument('--no-cache', action='store_true',
                        help='do not cache shellcheck output between calls')
    # argparse already falls back to sys.argv[1:] for argv=None; the previous
    # `argv or sys.argv[1:]` wrongly ignored an explicitly passed empty list
    args, unknown_args = parser.parse_known_args(argv)

    linter = YamlShellLinter(
        args.file,
        stdin_filename=args.stdin_filename,
        check_sourced=args.check_sourced,
        verbose=args.verbose,
        unknown_args=unknown_args,
    )
    # the cache key does not cover the contents of sourced files, so the
    # cache is bypassed when warnings from sourced files are requested
    use_cache = not args.check_sourced and not args.no_cache
    if use_cache:
        linter.load_cache()
    job_names = args.job or linter.all_job_names()
    returncode, messages = linter.merge([linter.lint_job(job_name)
                                         for job_name in job_names])
    if use_cache:
        linter.save_cache()

    if messages:  # avoid printing a spurious empty line
        print('\n'.join(messages))
    return returncode


# Run the CLI when executed as a script and propagate the linter's
# return code as the process exit status.
if __name__ == '__main__':
    sys.exit(main())
